{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Experiment AAMAS\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. [Reproducing the Braess paradox](#braess_paradox)\n",
    "2. [Computation time of algorithms to compute Nash equibrium in N-player and mean field games as a function of the number of players](#efficiency)\n",
    "3. [Sioux Falls, 14,000 vehicles with MFG](#sioux_falls)\n",
    "4. [Augmented Braess network with multiple origin destinations](#multiple_destinations)\n",
    "5. [Average deviation of the mean field equilibrium policy in the N-player Pigou network game as a function of N](#pigou_deviation)\n",
    "6. [Average deviation of the mean field equilibrium policy in the N-player Braess network game as a function of N](#braess_deviation)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Importing libraries\n",
    "If the import does not work please download and compile open spiel from source and check if you have all the required libraries."
   ]
  },
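  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If OpenSpiel is not installed at all, the cell below is one way to get it (a minimal sketch, assuming a pip-compatible environment; building from source, as recommended above, also works)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to install OpenSpiel from PyPI (assumes pip is available in this kernel).\n",
    "# !pip install open_spiel"
   ]
  },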
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from open_spiel.python import policy as policy_module\n",
    "from open_spiel.python.algorithms import best_response as best_response_module\n",
    "from open_spiel.python.algorithms import expected_game_score\n",
    "from open_spiel.python.games import dynamic_routing_to_mean_field_game\n",
    "from open_spiel.python.mfg.algorithms import distribution as distribution_module\n",
    "from open_spiel.python.mfg.algorithms import nash_conv as nash_conv_module\n",
    "from open_spiel.python.mfg.algorithms import policy_value\n",
    "from open_spiel.python.mfg.games import dynamic_routing as mean_field_routing_game\n",
    "\n",
    "from utils import *"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name='braess_paradox'></a>\n",
    "\n",
    "## 1. Reproducing the Braess paradox with the mean field routing game\n",
    "\n",
    "This is used to produce figure 1 of the AAMAS article."
   ]
  },
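  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a reminder of what the paradox looks like in its textbook form (illustrative linear latencies, not necessarily the exact parameters of `create_braess_network`): with unit demand and two routes of cost $x + 1$ each, where $x$ is the fraction of the flow on the route, the equilibrium splits the flow evenly and every driver pays $\\tfrac{1}{2} + 1 = \\tfrac{3}{2}$. After adding a zero-cost shortcut connecting the two routes, the unique equilibrium sends every driver through both congestible links, at cost $1 + 0 + 1 = 2 > \\tfrac{3}{2}$: adding a road increases everyone's travel time."
   ]
  },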
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BRAESS_NUM_VEHICLES = 4\n",
    "BRAESS_ORIGIN = 'A->B'\n",
    "BRAESS_DESTINATION = 'E->F'\n",
    "BRAESS_TIME_STEP_LENGTH = 0.25\n",
    "BRAESS_MAX_TIME_STEP = int(4.0/BRAESS_TIME_STEP_LENGTH) + 1\n",
    "\n",
    "BRAESS_GRAPH = create_braess_network(BRAESS_NUM_VEHICLES)\n",
    "plot_network_n_player_game(BRAESS_GRAPH)\n",
    "\n",
    "BRAESS_GAME, BRAESS_SEQ_GAME, BRAESS_MFG_GAME = create_games(\n",
    "    BRAESS_ORIGIN, BRAESS_DESTINATION, BRAESS_NUM_VEHICLES, BRAESS_GRAPH, BRAESS_MAX_TIME_STEP,\n",
    "    BRAESS_TIME_STEP_LENGTH)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Online Mirror Descent\n",
    "\n",
    "md_p_init = mirror_descent.MirrorDescent(BRAESS_MFG_GAME, lr=1)\n",
    "mfmd_timing, mfmd_policy, mfmd_nash_conv, mfmd_policy_value, md_p = online_mirror_descent(\n",
    "    BRAESS_MFG_GAME, 10, compute_metrics=True, return_policy=True, md_p=md_p_init)\n",
    "evolve_mean_field_game(BRAESS_MFG_GAME, mfmd_policy, BRAESS_GRAPH)"
   ]
  },
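  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The helper `online_mirror_descent` used above comes from `utils.py`, which is not shown in this notebook. Below is a rough sketch of what it computes, modeled on `online_mirror_descent_sioux_falls` defined in Section 3 and on the return conventions of the calls in this notebook; the actual helper may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def online_mirror_descent_sketch(mfg_game, number_of_iterations,\n",
    "                                 compute_metrics=False, return_policy=False,\n",
    "                                 md_p=None):\n",
    "  \"\"\"Sketch of the utils.online_mirror_descent helper (hypothetical).\"\"\"\n",
    "  md = md_p if md_p else mirror_descent.MirrorDescent(mfg_game)\n",
    "  tick_time = time.time()\n",
    "  for _ in range(number_of_iterations):\n",
    "    md.iteration()\n",
    "  timing = time.time() - tick_time\n",
    "  md_policy = md.get_policy()\n",
    "  if not compute_metrics:\n",
    "    return timing, md_policy\n",
    "  # Metrics: value of the policy under its own induced distribution and\n",
    "  # its Nash convergence (average deviation incentive).\n",
    "  distribution_mfg = distribution_module.DistributionPolicy(mfg_game, md_policy)\n",
    "  policy_value_ = policy_value.PolicyValue(\n",
    "      mfg_game, distribution_mfg, md_policy).value(mfg_game.new_initial_state())\n",
    "  nash_conv_md = nash_conv_module.NashConv(mfg_game, md_policy)\n",
    "  if return_policy:\n",
    "    return timing, md_policy, nash_conv_md, policy_value_, md\n",
    "  return timing, md_policy, nash_conv_md, policy_value_"
   ]
  },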
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name='efficiency'></a>\n",
    "## 2. Computation time of algorithms to compute Nash equibrium in N-player and mean field games as a function of the number of players.\n",
    "\n",
    "This is used to produce figure 2 of the AAMAS article.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "timing_n_player_results = {}\n",
    "timing_mean_field_results = {}\n",
    "NUM_ALGO_ITERATIONS = 10\n",
    "\n",
    "for num_vehicles in range(5, 45, 5):\n",
    "  braess_game, braess_seq_game, braess_mfg_game = create_games(\n",
    "      BRAESS_ORIGIN, BRAESS_DESTINATION, num_vehicles, BRAESS_GRAPH, BRAESS_MAX_TIME_STEP,\n",
    "      BRAESS_TIME_STEP_LENGTH)\n",
    "  ext_cfr_timing, ext_cfr_policy = external_sampling_monte_carlo_counterfactual_regret_minimization(braess_seq_game, NUM_ALGO_ITERATIONS)\n",
    "  mfmd_timing, mfmd_policy = online_mirror_descent(braess_mfg_game, NUM_ALGO_ITERATIONS, compute_metrics=False)\n",
    "  timing_n_player_results[num_vehicles] = ext_cfr_timing\n",
    "  timing_mean_field_results[num_vehicles] = mfmd_timing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(list(timing_mean_field_results), list(timing_mean_field_results.values()), '-o', label=f'{NUM_ALGO_ITERATIONS} iterations of MFG OMD')\n",
    "plt.plot(list(timing_n_player_results), list(timing_n_player_results.values()), '--xr', label=f'{NUM_ALGO_ITERATIONS} iterations of N-player CFR')\n",
    "plt.legend()\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Number of players')\n",
    "plt.ylabel('Computation time')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name='sioux_falls'></a>\n",
    "## 3. Solving large games with mean field online mirror descent algorithm: 14,000 vehicles in the Sioux Falls network\n",
    "\n",
    "This is used to produce figure 4 and 5 of the AAMAS article.\n",
    "Depending on the computer used, the computation can take a long time. On the MacBook Pro 2019 with macOS Big Sur 11.6 it tooks around 10 hours.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "SIOUX_FALLS_GRAPH, SIOUX_FALLS_OD_DEMAND = create_sioux_falls_network()\n",
    "plot_network_n_player_game(SIOUX_FALLS_GRAPH)\n",
    "\n",
    "SIOUX_FALLS_OD_DEMAND = [\n",
    "    dynamic_routing_utils.OriginDestinationDemand(f'bef_19->19', f'1->aft_1', 0, 7000),\n",
    "    dynamic_routing_utils.OriginDestinationDemand(f'bef_1->1', f'19->aft_19', 0, 7000)\n",
    "]\n",
    "\n",
    "SIOUX_FALLS_TIME_STEP_LENGTH = 0.5  # 0.2\n",
    "SIOUX_FALLS_MAX_TIME_STEP = int(40.0/SIOUX_FALLS_TIME_STEP_LENGTH) + 1  # 0.25\n",
    "\n",
    "SIOUX_MFG_GAME = mean_field_routing_game.MeanFieldRoutingGame(\n",
    "    {\"max_num_time_step\": SIOUX_FALLS_MAX_TIME_STEP, \"time_step_length\": SIOUX_FALLS_TIME_STEP_LENGTH},\n",
    "    network=SIOUX_FALLS_GRAPH, od_demand=SIOUX_FALLS_OD_DEMAND)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def online_mirror_descent_sioux_falls(mfg_game,\n",
    "                                      number_of_iterations,\n",
    "                                      md_p=None):\n",
    "  nash_conv_dict = {}\n",
    "  md = md_p if md_p else mirror_descent.MirrorDescent(mfg_game)\n",
    "  tick_time = time.time()\n",
    "  for i in range(number_of_iterations):\n",
    "    if i < 32:\n",
    "      md.iteration(learning_rate=1)\n",
    "    elif i < 64:\n",
    "      md.iteration(learning_rate=0.1)\n",
    "    else:\n",
    "      md.iteration(learning_rate=0.01)\n",
    "    md_policy = md.get_policy()\n",
    "    nash_conv_md = nash_conv_module.NashConv(mfg_game, md_policy)\n",
    "    nash_conv_dict[i] = nash_conv_md.nash_conv()\n",
    "    print((f\"Iteration {i}, Nash conv: {nash_conv_md.nash_conv()}, \"\n",
    "           f\"time: {time.time() - tick_time}\"))\n",
    "  timing = time.time() - tick_time\n",
    "  md_policy = md.get_policy()\n",
    "  distribution_mfg = distribution_module.DistributionPolicy(mfg_game, md_policy)\n",
    "  policy_value_ = policy_value.PolicyValue(\n",
    "      mfg_game, distribution_mfg, md_policy).value(mfg_game.new_initial_state())\n",
    "  nash_conv_md = nash_conv_module.NashConv(mfg_game, md_policy)\n",
    "  return timing, md_policy, nash_conv_md, policy_value_, md, nash_conv_dict\n",
    "\n",
    "md_p_init = mirror_descent.MirrorDescent(SIOUX_MFG_GAME, lr=1)\n",
    "mfmd_timing, mfmd_policy, mfmd_nash_conv, mfmd_policy_value, md_p, nash_conv_dict = online_mirror_descent_sioux_falls(\n",
    "    SIOUX_MFG_GAME, 100, md_p=md_p_init)\n",
    "\n",
    "print(f\"Online mirror descent nash conv: {mfmd_nash_conv.nash_conv()}\")\n",
    "print(f\"Online mirror descent timing: {mfmd_timing}\")\n",
    "\n",
    "tick_time = time.time()\n",
    "evolve_mean_field_game(SIOUX_MFG_GAME, mfmd_policy, SIOUX_FALLS_GRAPH)\n",
    "print(time.time() - tick_time)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(list(nash_conv_dict), list(nash_conv_dict.values()), 'x') #, label='Online mirror descent')\n",
    "plt.legend()\n",
    "plt.xlabel('Number of iterations')\n",
    "plt.ylabel('Average deviation incentive')\n",
    "plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name='multiple_destinations'></a>\n",
    "## 4. Augmented Braess network with multiple origin destinations.\n",
    "\n",
    "This is used to produce figure 7 of the AAMAS article."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "AUG_BRAESS_GRAPH = create_augmented_braess_network(150)\n",
    "plot_network_n_player_game(AUG_BRAESS_GRAPH)\n",
    "\n",
    "AUG_BRAESS_OD_DEMAND = [\n",
    "    dynamic_routing_utils.OriginDestinationDemand('A->B', 'E->F', 0, 50),\n",
    "    dynamic_routing_utils.OriginDestinationDemand('A->B', 'E->F', 0.5, 50),\n",
    "    dynamic_routing_utils.OriginDestinationDemand('A->B', 'E->F', 1, 50),\n",
    "    dynamic_routing_utils.OriginDestinationDemand('A->B', 'D->G', 0, 50),\n",
    "    dynamic_routing_utils.OriginDestinationDemand('A->B', 'D->G', 1, 50)]\n",
    "\n",
    "AUG_BRAESS_TIME_STEP_LENGTH = 0.05\n",
    "AUG_BRAESS_MAX_TIME_STEP = int(8.0/AUG_BRAESS_TIME_STEP_LENGTH) + 1\n",
    "\n",
    "AUG_BRAESS_MFG_GAME = mean_field_routing_game.MeanFieldRoutingGame(\n",
    "    {\"max_num_time_step\": AUG_BRAESS_MAX_TIME_STEP, \"time_step_length\": AUG_BRAESS_TIME_STEP_LENGTH},\n",
    "    network=AUG_BRAESS_GRAPH, od_demand=AUG_BRAESS_OD_DEMAND)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Online Mirror Descent\n",
    "\n",
    "md_p_init = mirror_descent.MirrorDescent(AUG_BRAESS_MFG_GAME, lr=1)\n",
    "mfmd_timing, mfmd_policy, mfmd_nash_conv, mfmd_policy_value, md_p = online_mirror_descent(\n",
    "    AUG_BRAESS_MFG_GAME, 20, compute_metrics=True, return_policy=True, md_p=md_p_init)\n",
    "evolve_mean_field_game(AUG_BRAESS_MFG_GAME, mfmd_policy, AUG_BRAESS_GRAPH)\n",
    "\n",
    "print(f\"Online mirror descent nash conv: {mfmd_nash_conv.nash_conv()}\")\n",
    "print(f\"Online mirror descent timing: {mfmd_timing}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name='pigou_deviation'></a>\n",
    "## 5. Average deviation of the mean field equilibrium policy in the N-player Pigou network game as a function of N.\n",
    "\n",
    "This is used to produce figure 3 of the AAMAS article."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_series_parallel_network(num_network_in_series,\n",
    "                                   time_step_length=1,\n",
    "                                   capacity=1):\n",
    "  i = 0\n",
    "  origin = \"A_0->B_0\"\n",
    "  graph_dict = {}\n",
    "  while i < num_network_in_series:\n",
    "    graph_dict.update({\n",
    "        f\"A_{i}\": {\n",
    "            \"connection\": {\n",
    "                f\"B_{i}\": {\n",
    "                    \"a\": 0,\n",
    "                    \"b\": 1.0,\n",
    "                    \"capacity\": capacity,\n",
    "                    \"free_flow_travel_time\": time_step_length\n",
    "                }\n",
    "            },\n",
    "            \"location\": [0 + 3 * i, 0]\n",
    "        },\n",
    "        f\"B_{i}\": {\n",
    "            \"connection\": {\n",
    "                f\"C_{i}\": {\n",
    "                    \"a\": 0.0,\n",
    "                    \"b\": 1.0,\n",
    "                    \"capacity\": capacity,\n",
    "                    \"free_flow_travel_time\": 2.0\n",
    "                },\n",
    "                f\"D_{i}\": {\n",
    "                    \"a\": 2.0,\n",
    "                    \"b\": 1.0,\n",
    "                    \"capacity\": capacity,\n",
    "                    \"free_flow_travel_time\": 1.0\n",
    "                }\n",
    "            },\n",
    "            \"location\": [1 + 3 * i, 0]\n",
    "        },\n",
    "        f\"C_{i}\": {\n",
    "            \"connection\": {\n",
    "                f\"A_{i+1}\": {\n",
    "                    \"a\": 0,\n",
    "                    \"b\": 1.0,\n",
    "                    \"capacity\": capacity,\n",
    "                    \"free_flow_travel_time\": time_step_length\n",
    "                }\n",
    "            },\n",
    "            \"location\": [2 + 3 * i, 1]\n",
    "        },\n",
    "        f\"D_{i}\": {\n",
    "            \"connection\": {\n",
    "                f\"A_{i+1}\": {\n",
    "                    \"a\": 0,\n",
    "                    \"b\": 1.0,\n",
    "                    \"capacity\": capacity,\n",
    "                    \"free_flow_travel_time\": time_step_length\n",
    "                }\n",
    "            },\n",
    "            \"location\": [2 + 3 * i, -1]\n",
    "        }\n",
    "    })\n",
    "    i += 1\n",
    "  graph_dict[f\"A_{i}\"] = {\n",
    "      \"connection\": {\n",
    "          \"END\": {\n",
    "              \"a\": 0,\n",
    "              \"b\": 1.0,\n",
    "              \"capacity\": capacity,\n",
    "              \"free_flow_travel_time\": time_step_length\n",
    "          }\n",
    "      },\n",
    "      \"location\": [0 + 3 * i, 0]\n",
    "  }\n",
    "  graph_dict[\"END\"] = {\"connection\": {}, \"location\": [1 + 3 * i, 0]}\n",
    "  time_horizon = int(5.0 * (num_network_in_series + 1) / time_step_length)\n",
    "  destination = f\"A_{i}->END\"\n",
    "  adjacency_list = {\n",
    "      key: list(value[\"connection\"].keys())\n",
    "      for key, value in graph_dict.items()\n",
    "  }\n",
    "  bpr_a_coefficient = {}\n",
    "  bpr_b_coefficient = {}\n",
    "  capacity = {}\n",
    "  free_flow_travel_time = {}\n",
    "  for o_node, value_dict in graph_dict.items():\n",
    "    for d_node, section_dict in value_dict[\"connection\"].items():\n",
    "      road_section = dynamic_routing_utils._nodes_to_road_section(\n",
    "          origin=o_node, destination=d_node)\n",
    "      bpr_a_coefficient[road_section] = section_dict[\"a\"]\n",
    "      bpr_b_coefficient[road_section] = section_dict[\"b\"]\n",
    "      capacity[road_section] = section_dict[\"capacity\"]\n",
    "      free_flow_travel_time[road_section] = section_dict[\n",
    "          \"free_flow_travel_time\"]\n",
    "  node_position = {key: value[\"location\"] for key, value in graph_dict.items()}\n",
    "  return dynamic_routing_utils.Network(\n",
    "      adjacency_list,\n",
    "      node_position=node_position,\n",
    "      bpr_a_coefficient=bpr_a_coefficient,\n",
    "      bpr_b_coefficient=bpr_b_coefficient,\n",
    "      capacity=capacity,\n",
    "      free_flow_travel_time=free_flow_travel_time\n",
    "  ), origin, destination, time_horizon"
   ]
  },
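  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For illustration, the next cell builds the smallest instance of this family (a single Pigou block) and plots it with the same helper used for the other networks. The `example_*` variable names exist only for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example: one Pigou block with capacity 10, plotted for inspection.\n",
    "example_graph, example_origin, example_destination, example_horizon = (\n",
    "    create_series_parallel_network(1, time_step_length=0.05, capacity=10))\n",
    "print(f\"origin: {example_origin}, destination: {example_destination}, \"\n",
    "      f\"time horizon: {example_horizon}\")\n",
    "plot_network_n_player_game(example_graph)"
   ]
  },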
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GoUp(PurePolicyResponse):\n",
    "\n",
    "  def pure_action(self, state):\n",
    "    location = state.get_current_vehicle_locations()[self.player_id].split(\n",
    "        \"->\")[1]\n",
    "    if location == \"B_0\":\n",
    "      return state.get_game().network.get_action_id_from_movement(\"B_0\", \"C_0\")\n",
    "    else:\n",
    "      return 0\n",
    "\n",
    "def compute_regret_policy_against_pure_policy_pigou_sim_game(game,\n",
    "                                                            policy,\n",
    "                                                            compute_true_value=False,\n",
    "                                                            num_sample=100):\n",
    "  time_tick = time.time()\n",
    "  if compute_true_value:\n",
    "    expected_value_policy = expected_game_score.policy_value(\n",
    "        game.new_initial_state(), policy)[0]\n",
    "  else:\n",
    "    expected_value_policy = get_expected_value_sim_game(game, policy, num_sample)\n",
    "  worse_regret = 0\n",
    "  deviation_policy = GoUp(game, policy, 0)\n",
    "  if compute_true_value:\n",
    "    expected_value_noise = expected_game_score.policy_value(\n",
    "        game.new_initial_state(), deviation_policy)[0]\n",
    "  else:\n",
    "    expected_value_noise = get_expected_value_sim_game(\n",
    "        game, deviation_policy, num_sample, player=0)\n",
    "  approximate_regret = expected_value_noise - expected_value_policy\n",
    "  worse_regret = max(worse_regret, approximate_regret)\n",
    "  return worse_regret, time.time() - time_tick"
   ]
  },
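  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The function above estimates, for player 0, the deviation incentive\n",
    "$$\\hat r \\;=\\; \\hat{\\mathbb{E}}\\big[u_0(\\text{GoUp}, \\pi_{-0})\\big] \\;-\\; \\hat{\\mathbb{E}}\\big[u_0(\\pi)\\big],$$\n",
    "where $\\pi$ is the policy derived from the mean field equilibrium and each expectation is either exact (`compute_true_value=True`) or estimated from `num_sample` simulated games."
   ]
  },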
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_of_tests = 5\n",
    "\n",
    "computation_time_pure_policy_nash_conv_dict_large = {}\n",
    "pure_policy_nash_conv_n_player_dict_large = {}\n",
    "\n",
    "PIGOU_TIME_STEP_LENGTH = 0.05\n",
    "\n",
    "for pigou_num_vehicle in [x for x in range(1, 10, 1)] + [x for x in range(10, 100, 10)]:\n",
    "  PIGOU_GRAPH, PIGOU_ORIGIN, PIGOU_DESTINATION, PIGOU_MAX_TIME_STEP = create_series_parallel_network(\n",
    "  1, time_step_length=PIGOU_TIME_STEP_LENGTH, capacity=pigou_num_vehicle)\n",
    "\n",
    "  PIGOU_GAME, PIGOU_SEQ_GAME, PIGOU_MFG_GAME = create_games(\n",
    "        PIGOU_ORIGIN, PIGOU_DESTINATION, pigou_num_vehicle, PIGOU_GRAPH, PIGOU_MAX_TIME_STEP,\n",
    "        PIGOU_TIME_STEP_LENGTH)\n",
    "\n",
    "  md_p_init = mirror_descent.MirrorDescent(PIGOU_MFG_GAME, lr=1)\n",
    "  mfmd_timing, mfmd_policy, mfmd_nash_conv, mfmd_policy_value, md_p = online_mirror_descent(\n",
    "      PIGOU_MFG_GAME, 10, compute_metrics=True, return_policy=True, md_p=md_p_init)\n",
    "  print(f\"Online mirror descent nash conv: {mfmd_nash_conv.nash_conv()}\")\n",
    "  mfmd_policy_n_player_derived = dynamic_routing_to_mean_field_game.DerivedNPlayerPolicyFromMeanFieldPolicy(\n",
    "        PIGOU_GAME, mfmd_policy)\n",
    "\n",
    "  nash_conv_n_player_list = []\n",
    "  computation_time_list = []\n",
    "\n",
    "  # nash_conv_n_player, computation_time = compute_regret_policy_against_pure_policy_pigou_sim_game(\n",
    "  #   PIGOU_GAME, mfmd_policy_n_player_derived, compute_true_value=True)\n",
    "  for _ in range(num_of_tests):\n",
    "    nash_conv_n_player, computation_time = compute_regret_policy_against_pure_policy_pigou_sim_game(\n",
    "      PIGOU_GAME, mfmd_policy_n_player_derived, compute_true_value=False)\n",
    "    nash_conv_n_player_list.append(nash_conv_n_player)\n",
    "    computation_time_list.append(computation_time)\n",
    "    print(f\"Sampled exploitability: {nash_conv_n_player}, computed in {computation_time}\")\n",
    "  computation_time_pure_policy_nash_conv_dict_large[pigou_num_vehicle] = computation_time_list\n",
    "  pure_policy_nash_conv_n_player_dict_large[pigou_num_vehicle] = nash_conv_n_player_list\n",
    "\n"
   ]
  },
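  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell computes the average deviation incentive of the mean field policy in the $N$-player Pigou game in closed form. Under this policy each player takes the congested branch independently with probability $\\tfrac{1}{2}$, so with $X \\sim \\mathrm{Binomial}(N-1, \\tfrac{1}{2})$ other players on that branch, a player who also takes it pays in expectation\n",
    "$$\\mathbb{E}[t_{\\text{congested}}] = \\mathbb{E}\\Big[1.05 + 2\\,\\frac{X+1}{N}\\Big] = 2.05 + \\frac{1}{N},$$\n",
    "while the constant branch costs $2.05$ (the constants follow from the link costs in `create_series_parallel_network` with `time_step_length = 0.05`). Each player mixes $\\tfrac{1}{2}$-$\\tfrac{1}{2}$, so their expected cost is the average of the two branch costs while the best response costs $2.05$: the average deviation incentive is $\\big(\\mathbb{E}[t_{\\text{congested}}] - 2.05\\big)/2 = \\tfrac{1}{2N}$."
   ]
  },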
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.special\n",
    "import matplotlib.pyplot as plt\n",
    "pigou_true_average_deviation_incentive = {}\n",
    "for num_player in range(1, 100):\n",
    "  probs = {}\n",
    "\n",
    "  for x in range(num_player):\n",
    "    probs[(x+1)/num_player] = scipy.special.binom(num_player-1, x)*(0.5**(num_player-1))\n",
    "\n",
    "  assert abs(sum(probs.values())-1) < 1e-4\n",
    "  e_tt = sum(p*(1.05+2*x) for x, p in probs.items())\n",
    "  pigou_true_average_deviation_incentive[num_player] = (e_tt-2.05)/2\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "plt.errorbar(\n",
    "    list(pure_policy_nash_conv_n_player_dict_large),\n",
    "    [sum(x)/len(x) for x in pure_policy_nash_conv_n_player_dict_large.values()],\n",
    "    yerr=[(max(x)-min(x))/2 for x in pure_policy_nash_conv_n_player_dict_large.values()], fmt='-xr', # ls='none',\n",
    "    label='Sampled') #  (mean, min and max, 100 sampled, 5 times)\n",
    "plt.plot(list(pigou_true_average_deviation_incentive), list(pigou_true_average_deviation_incentive.values()), '--', label='True Value')\n",
    "plt.legend()\n",
    "plt.xlabel('Number of players')\n",
    "plt.ylabel('Average deviation incentive')  # of mean field equilibrium policy\n",
    "plt.show()\n",
    "\n",
    "plt.plot(list(computation_time_pure_policy_nash_conv_dict_large), list([sum(x)/len(x) for x in computation_time_pure_policy_nash_conv_dict_large.values()]), label='Computation time sampled Nash conv')\n",
    "plt.legend()\n",
    "plt.xlabel('Number of players')\n",
    "plt.ylabel('Average deviation incentive computation time')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a name='braess_deviation'></a>\n",
    "## 6. Average deviation of the mean field equilibrium policy in the N-player Braess network game as a function of N.\n",
    "\n",
    "This is used to produce figure 6 of the AAMAS article."
   ]
  },
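  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The next cell evaluates the mean field equilibrium policy in the $N$-player Braess game by exact enumeration rather than sampling. The policy plays the three paths up ($A \\to B \\to C \\to E \\to F$), middle ($A \\to B \\to C \\to D \\to E \\to F$) and down ($A \\to B \\to D \\to E \\to F$) with probabilities $\\tfrac{1}{4}$, $\\tfrac{1}{2}$ and $\\tfrac{1}{4}$. Since the other players choose their paths independently, the path counts follow a multinomial distribution, which the cell enumerates (as a binomial over up-or-middle versus down, then a binomial splitting up from middle). For each realization, a small event-driven simulation of the link dynamics gives the arrival time of a deviating player on each path, and the average deviation incentive is the expected travel time under the mean field policy minus the travel time of the best path."
   ]
  },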
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.special\n",
    "\n",
    "p_middle = 0.50\n",
    "p_up = 0.25\n",
    "p_down = 0.25\n",
    "prob_paths = {'up': 0.25, 'middle': 0.5, 'down': 0.25}\n",
    "\n",
    "time_step = 0.1\n",
    "average_deviation_incentive_braess = {}\n",
    "for num_other_player in range(1, 60):\n",
    "  # print(num_other_player)\n",
    "  def count_lien(location, volume):\n",
    "    if location == 'B->C' or location == 'D->E':\n",
    "      return 1 + volume/(num_other_player+1)\n",
    "    elif location == 'A->B' or new_location == 'E->F':\n",
    "      return 0\n",
    "    elif location == 'C->D':\n",
    "      return 0.25\n",
    "    elif location == 'B->D' or location == 'C->E':\n",
    "      return 2\n",
    "    raise ValueError()\n",
    "  probs_go_up = {}\n",
    "  probs_go_middle = {}\n",
    "  probs_each_path = {}\n",
    "\n",
    "  for x in range(num_other_player):\n",
    "    probs_go_up[x] = scipy.special.binom(num_other_player-1, x) * ((p_up+p_middle)**x) * ((p_down)**(num_other_player-1-x))\n",
    "    for y in range(num_other_player):\n",
    "      probs_go_middle[(y,x)] = scipy.special.binom(x, y) * ((p_middle/(p_up+p_middle))**y) * ((p_up/(p_up+p_middle))**(x-y))\n",
    "      if x-y >= 0:\n",
    "        probs_each_path[(x-y, y, num_other_player-x)] = probs_go_up[x] * probs_go_middle[(y,x)]\n",
    "\n",
    "  returns_per_policy = {}\n",
    "  for policy_tested in range(3):\n",
    "    returns = 0\n",
    "    for key in probs_each_path:\n",
    "      rewards = {}\n",
    "      # Do the simulation if the person was on path up\n",
    "      num_paths_up, num_paths_middle, num_paths_down = key\n",
    "      if policy_tested == 0:\n",
    "        path_taken = 'up'\n",
    "        num_paths_up += 1\n",
    "      if policy_tested == 1:\n",
    "        path_taken = 'middle'\n",
    "        num_paths_middle += 1\n",
    "      if policy_tested == 2:\n",
    "        path_taken = 'down'\n",
    "        num_paths_down += 1\n",
    "      states = {'A->B_up': 0.0, 'A->B_middlemilieu': 0.0, 'A->B_down': 0.0}\n",
    "      current_time_step = 0.0\n",
    "      while True:\n",
    "        min_waiting_time = min((x for x in states.items() if x[1]>0 or 'E->F' not in x[0]), key=lambda x: x[1])[1]\n",
    "        # print(min_waiting_time)\n",
    "        current_time_step += min_waiting_time\n",
    "        new_locations = {}\n",
    "        new_states = {}\n",
    "        for location_path, waiting_time in states.items():\n",
    "          location, path = location_path.split('_')\n",
    "          if path == 'up':\n",
    "            if waiting_time == min_waiting_time:\n",
    "              if location == 'A->B':\n",
    "                new_location = 'B->C'\n",
    "              elif location == 'B->C':\n",
    "                new_location = 'C->E'\n",
    "              elif location == 'C->E':\n",
    "                new_location = 'E->F'\n",
    "              elif location == 'E->F':\n",
    "                new_location = 'E->F'\n",
    "              else:\n",
    "                raise ValueError()\n",
    "              new_states[f\"{new_location}_up\"] = -1\n",
    "            else:\n",
    "              new_location = location\n",
    "              new_states[f\"{new_location}_uphaut\"] = waiting_time-min_waiting_time\n",
    "            if not new_location in new_locations:\n",
    "              new_locations[new_location] = 0\n",
    "            new_locations[new_location] += num_paths_up\n",
    "          elif path == 'middle':\n",
    "            if waiting_time == min_waiting_time:\n",
    "              if location == 'A->B':\n",
    "                new_location = 'B->C'\n",
    "              elif location == 'B->C':\n",
    "                new_location = 'C->D'\n",
    "              elif location == 'C->D':\n",
    "                new_location = 'D->E'\n",
    "              elif location == 'D->E':\n",
    "                new_location = 'E->F'\n",
    "              elif location == 'E->F':\n",
    "                new_location = 'E->F'\n",
    "              else:\n",
    "                raise ValueError()\n",
    "              new_states[f\"{new_location}_middle\"] = -1\n",
    "            else:\n",
    "              new_location = location\n",
    "              new_states[f\"{new_location}_middle\"] = waiting_time-min_waiting_time\n",
    "            if not new_location in new_locations:\n",
    "              new_locations[new_location] = 0\n",
    "            new_locations[new_location] += num_paths_middle\n",
    "          elif path == 'down':\n",
    "            if waiting_time == min_waiting_time:\n",
    "              if location == 'A->B':\n",
    "                new_location = 'B->D'\n",
    "              elif location == 'B->D':\n",
    "                new_location = 'D->E'\n",
    "              elif location == 'D->E':\n",
    "                new_location = 'E->F'\n",
    "              elif location == 'E->F':\n",
    "                new_location = 'E->F'\n",
    "              else:\n",
    "                raise ValueError()\n",
    "              new_states[f\"{new_location}_down\"] = -1\n",
    "            else:\n",
    "              new_location = location\n",
    "              new_states[f\"{new_location}_down\"] = waiting_time-min_waiting_time\n",
    "            if not new_location in new_locations:\n",
    "              new_locations[new_location] = 0\n",
    "            new_locations[new_location] += num_paths_down\n",
    "        should_stop = True\n",
    "        for location_path, waiting_time in new_states.items():\n",
    "          if location_path.split('_')[0] != 'E->F':\n",
    "            should_stop = False\n",
    "          else:\n",
    "            path = location_path.split('_')[1]\n",
    "            if path not in rewards:\n",
    "              rewards[path] = current_time_step\n",
    "          if waiting_time == -1:\n",
    "            new_location = location_path.split('_')[0]\n",
    "            new_states[location_path] = count_lien(new_location, new_locations[new_location])\n",
    "        states = new_states\n",
    "        if should_stop:\n",
    "          break\n",
    "      returns += probs_each_path[key] * rewards[path_taken]\n",
    "    returns_per_policy[path_taken] = returns\n",
    "  returns = 0\n",
    "  for k, v in returns_per_policy.items():\n",
    "    returns += v * prob_paths[k]\n",
    "  average_deviation_incentive_braess[num_other_player+1] = returns - min(returns_per_policy.values())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(list(average_deviation_incentive_braess), list(average_deviation_incentive_braess.values()), 'x', label='mean field policy in N player')\n",
    "plt.legend()\n",
    "# plt.title('Average deviation incentive of the mean field policy in the N player game as a function of N.')\n",
    "plt.xlabel('Number of players')\n",
    "plt.ylabel('Average deviation incentive')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
